
feat(ascend): op-norm-rope group (re-PR after #72 revert) — Swiglu, SiluAndMul, CausalSoftmax, RmsNorm, AddRmsNorm, RotaryEmbedding #73

Open

zhangyue207 wants to merge 29 commits into master from feat/ascend-op-norm-rope-v2

Conversation


@zhangyue207 zhangyue207 commented Apr 25, 2026

Summary

Operator implementations

| Op | Impls / Notes |
|---|---|
| Swiglu | `aclnnSilu` + `aclnnMul` (kernel.h); fused `aclnnSwiGlu` (kernel_fused.h) |
| SiluAndMul | custom AscendC kernel (kernel.h routes to `ascend_kernel::silu_and_mul`) |
| CausalSoftmax | `aclnnSoftmax` + pre-computed mask (kernel.h) |
| RmsNorm | `aclnnRmsNorm` (kernel.h); custom AscendC variant (kernel_custom.h) |
| AddRmsNorm | three impls — decomposed `aclnnAdd` + `aclnnRmsNorm` (kernel.h); fused `aclnnAddRmsNorm` (kernel_fused.h); custom AscendC (kernel_custom.h) |
| RotaryEmbedding | three impls — `aclnnApplyRotaryPosEmbV2` (kernel.h); ATB `RopeParam` covering both neox and interleave layouts (kernel_atb.h); `aclnnRopeWithSinCosCache` for partial rotary (kernel_sincos_cache.h); see vLLM alignment below |

vLLM API alignment

src/base/rotary_embedding.h: `query_out` / `key_out` are now `std::optional<Tensor>`. When omitted, the kernel writes back in-place on `query` / `key`, matching vLLM's `RotaryEmbedding.forward(positions, query, key)` in-place signature. Explicit out buffers are still supported, and all three Ascend impls resolve the optional via `value_or(query)`.

`test_rotary_embedding_inplace` covers fp16 / bf16 × impl=0 / impl=1; tolerance `atol=5e-3` matches the V2 ~4 ULP fp16 accumulator error documented in kernel.h. The Linear, SiluAndMul, and AddRmsNorm constructors are similarly aligned with their vLLM counterparts.
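The optional-out resolution can be sketched in plain Python (a minimal mock of the `value_or(query)` behavior, not the actual C++ kernel code):

```python
def resolve_out(query, query_out=None):
    # When the caller omits `query_out`, the kernel writes back
    # in-place on `query`; an explicit out buffer is used as-is.
    return query_out if query_out is not None else query

buf = [1.0, 2.0]
assert resolve_out(buf) is buf          # in-place path
out = [0.0, 0.0]
assert resolve_out(buf, out) is out     # explicit out buffer
```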

Motivation

The norm + RoPE op group is the last missing layer-level kernel set on Ascend; without it, every transformer model on Ascend has to fall back to a per-op decomposition path. This PR re-introduces the work originally landed in #66 (and reverted in #72 due to the non-Ascend build break), plus the build gate that prevents a recurrence.

Type of Change

  • feat — new feature / new operator / new platform
  • fix — bug fix
  • perf — performance improvement (no behavioral change)
  • refactor — code restructuring without behavior change
  • test — adding or fixing tests only
  • docs — documentation only
  • build / ci — build system or CI configuration
  • chore — tooling, formatting, or other non-code changes
  • Breaking change (requires a ! in the Conventional Commits prefix or a BREAKING CHANGE: footer)

Platforms Affected

  • CPU (WITH_CPU)
  • NVIDIA (WITH_NVIDIA)
  • Iluvatar (WITH_ILUVATAR)
  • MetaX (WITH_METAX)
  • Cambricon (WITH_CAMBRICON)
  • Moore (WITH_MOORE)
  • Ascend (WITH_ASCEND)
  • PyTorch C++ bindings (WITH_TORCH)
  • Build system / CMake / CI — WITH_ASCEND-gate fix in src/CMakeLists.txt
  • Python bindings / user-facing API — std::optional<Tensor> outs on RotaryEmbedding (additive)

Test Results on Supported Platforms

| Platform | Built | pytest Result | Notes / Hardware |
|---|---|---|---|
| CPU | not built | not run | not affected |
| NVIDIA | not built | not run | no Ascend-side change touches this platform; CI-verified |
| Iluvatar | not built | not run | no Ascend-side change touches this platform; CI-verified |
| MetaX | not built | not run | no Ascend-side change touches this platform; CI-verified |
| Cambricon | not built | not run | no Ascend-side change touches this platform; CI-verified |
| Moore | not built | not run | no Ascend-side change touches this platform; CI-verified |
| Ascend | Successfully installed InfiniOps-0.1.0 | 2290 passed, 1624 skipped in 20.52s (HEAD 07eea80, post-rebase) | Ascend 910B + CANN 8.5.1 |

Earlier full-suite run on the original branch tip, prior to the upstream conftest redesign in 07eea80: 3347 passed, 1684 skipped, 0 failed.

Full `pytest` output (Ascend, post-rebase)
```
$ pytest tests/ --devices ascend -v --tb=short
...
===================== 2290 passed, 1624 skipped in 20.52s ======================
```

Checklist

Every contributor must verify every item below before requesting
review. Tick each box only after the check has actually been performed —
do not tick speculatively. If an item truly does not apply, replace the
checkbox with N/A and briefly explain why in an inline comment.

Title, Branch, and Commits

  • PR title follows Conventional Commits: feat(ascend): ….
  • Branch name follows <type>/xxx-yyyy-zzzz: feat/ascend-op-norm-rope-v2.
  • Each commit message follows Conventional Commits.
  • Large PR — every commit is meaningful, well-formed, and independently reviewable (CONTRIBUTING.md §Pull Requests, #1).
  • No stray merge commits from master — branch is rebased cleanly on top of current master.
  • No fixup! / squash! / wip commits remain.

Scope and Design

  • Changes are minimal — every commit traces to either a norm/RoPE operator added or the build-gate fix.
  • No dead code, commented-out blocks, debug prints, or TODO without an owner and issue link.
  • No unrelated formatting churn that would obscure the diff.
  • Public API change (std::optional<Tensor> outs on RotaryEmbedding) is intentional, documented in the Summary above, and reflected in callers/tests (test_rotary_embedding_inplace).

General Code Hygiene (applies to all languages)

  • The code is self-explanatory; comments were added only where the why is non-obvious.
  • Every modified or added file ends with a single trailing newline.
  • No trailing whitespace, tab/space mixing, or stray BOMs.
  • Identifiers in comments and error messages are wrapped in backticks.
  • All comments and error messages are in English.
  • Comments and error messages are complete sentences — capitalized first letter, terminal punctuation.

C++ Specific

  • Code follows the Google C++ Style Guide.
  • clang-format (version 21, per .github/workflows/clang-format.yml) has been run; the diff is clean.
  • clang-tidy concerns (per .clang-tidy) have been reviewed — no new warnings beyond the existing baseline.
  • Operator parameter order is inputs first, attributes, outputs last; naming follows PyTorch → ONNX → CUDA API precedence.
  • No exceptions are thrown. Error paths use assert with messages that include __FILE__, __LINE__, and __func__.
  • Error and warning message wording follows the LLVM Coding Standards.
  • Kernel files are named correctly: custom impls use kernel / kernel_custom; library-based impls use the library name (kernel_fused.h, kernel_atb.h, kernel_sincos_cache.h).
  • Kernel and kernel launcher are in separate files (launcher in .h; AscendC kernel under src/ascend/custom/<op>/op_kernel/).
  • Constructor initializer list order matches member declaration order.
  • Exactly one blank line between classes, between classes and functions, and between functions.
  • Exactly one blank line between members within a class.
  • Exactly one blank line before and after the contents of a namespace.
  • New operators added via src/base/<op>.h (inheriting Operator<Op>) with platform implementations under src/<platform>/<op>/ inheriting the base.
  • No raw new/delete; RAII / smart pointers / existing allocators are used.

Python Specific

  • Code is PEP 8 compliant; ruff check passes cleanly on CI (.github/workflows/ruff.yml).
  • ruff format --check passes cleanly.
  • Comments are complete English sentences; backticks used for code references.
  • Framework-specific conventions (e.g. lowercase pytest.skip messages without terminal period) are honored.
  • No blank line between function signature and body when there is no docstring or comment.
  • A blank line is present before and after if, for, and similar control-flow statements.
  • A blank line appears before each return, except when it directly follows a control-flow statement.
  • Docstrings (where present) follow PEP 257.
  • Type hints are added / kept consistent with the surrounding code.

Testing

  • pytest was run locally on Ascend (the only platform whose runtime code this PR touches); results recorded in the table above.
  • For non-Ascend platforms, the build-gate fix is verified by CMake configure (-DWITH_ASCEND=OFF -DWITH_CPU=ON); full pytest deferred to CI.
  • New functionality has matching tests under tests/test_silu_and_mul.py, test_swiglu.py, test_rms_norm.py, test_add_rms_norm.py, test_causal_softmax.py, test_rotary_embedding.py, test_rotary_embedding_inplace.py.
  • Tests use pytest.mark.parametrize correctly: dependent parameters share one decorator, independent parameters use separate decorators ordered by parameter declaration.
  • Where appropriate, pytest.mark.auto_act_and_assert is used and the test returns a Payload whose func and ref share the same calling convention.
  • Default dtype / device parameterization is relied on, or overridden with explicit pytest.mark.parametrize.
  • N/A — No new flaky-under-parallelism test was introduced.
  • Bug-fix regression: the WITH_ASCEND gate is verified by cmake -DWITH_ASCEND=OFF -DWITH_CPU=ON configuring cleanly, where the pre-fix branch failed with No target "no_workspace_kernel".

Build, CI, and Tooling

  • The project builds cleanly from a fresh directory with pip install -e .[dev] --no-build-isolation on Ascend.
  • compile_commands.json still regenerates.
  • N/A — No new backend / device added.
  • Only one CUDA-like GPU backend is selectable at a time — the existing mutual-exclusion check in CMakeLists.txt is intact.
  • Both CI workflows (clang-format.yml, ruff.yml) are expected to be green on CI.
  • N/A — No new runtime dependency added.

Documentation

  • N/A — No README.md / CONTRIBUTING.md change required; the std::optional API change is documented inline in the Summary above.
  • New operators carry header comments; new base header src/base/silu_and_mul.h carries a header comment.
  • N/A — No user-visible breaking change. The std::optional<Tensor> outs on RotaryEmbedding are additive: existing callers passing explicit outs continue to work.

Security and Safety

  • No secrets, access tokens, internal URLs, customer data, or personal hardware identifiers committed.
  • No third-party code added.
  • No unsafe pointer arithmetic, uninitialized reads, or missing bounds checks introduced.

zhangyue added 26 commits April 25, 2026 19:20
… RmsNorm, AddRmsNorm, ApplyRotaryPosEmb, RotaryEmbedding

Seven layer-level Ascend operators:

| op | impl |
|---|---|
| Swiglu | aclnnSilu + aclnnMul (decomposed); `kernel_fused.h` wraps fused swiglu where available |
| SiluAndMul | custom AscendC kernel |
| CausalSoftmax | aclnnSoftmax + pre-computed mask |
| RmsNorm | aclnnRmsNorm (kernel.h); custom AscendC variant (kernel_custom.h) |
| AddRmsNorm | 3 impls: decomposed aclnnAdd+aclnnRmsNorm (kernel.h); fused aclnnAddRmsNorm (kernel_fused.h); custom AscendC (kernel_custom.h) |
| ApplyRotaryPosEmb | aclnnApplyRotaryPosEmbV2 (kernel.h); ATB RopeParam (kernel_atb.h) |
| RotaryEmbedding | **3 impls**: aclnnApplyRotaryPosEmbV2 (kernel.h); ATB RopeParam with both neox/interleave (kernel_atb.h); aclnnRopeWithSinCosCache for partial rotary (kernel_sincos_cache.h) |

Bundles the RotaryEmbedding API alignment: `query_out` / `key_out`
are now `std::optional<Tensor>` — omitted → inplace on `query` / `key`
(matches vLLM `RotaryEmbedding.forward(positions, query, key)`).

New `src/base/<op>.h`: apply_rotary_pos_emb, silu_and_mul.
Modified: add_rms_norm (constructor signature alignment),
rotary_embedding (optional query_out/key_out).
…rnel registration

- swiglu/kernel_fused.h: release() cat_out_cache_ and out_staging_cache_
  to avoid double-free; drop aclDestroyTensorList per 64c367c convention.
- silu_and_mul/kernel.h: release() out_staging_cache_ to avoid double-free.
- custom/CMakeLists.txt: add add_rms_norm sources to OP_SRCS and register
  its op_kernel via ascendc_library(no_workspace_kernel ...); without
  this, aclrtlaunch_add_rms_norm has no backing implementation.
- `x1/x2/gamma/y_out/x_out` -> `input/other/weight/out/rstd_out`.
- Propagate through base header, all three Ascend kernel variants
  (`kernel.h`, `kernel_fused.h`, `kernel_custom.h`), and test file.
- Remove stale `rstd_shape_` field from base (unused; `kernel.h` holds
  its own copy).
- Upgrade assertion messages to complete sentences with backticked
  identifiers.
… kernels

- Wrap `aclnn*` / `aclrt*` identifiers in backticks and ensure
  complete-sentence, period-terminated comments per CONTRIBUTING.md.
- `silu_and_mul` base header: upgrade assertion message to a
  complete sentence with backticked identifiers.
- Files touched: causal_softmax/kernel.h, rms_norm/kernel.h,
  swiglu/kernel.h, swiglu/kernel_fused.h, base/silu_and_mul.h.
…d coverage

- Wire `implementation_index` into joint `(device, implementation_index)`
  parametrize via conftest; enforces fixture symmetry with `test_swiglu.py`.
- Add two non-contiguous shape cases to exercise the staging-buffer copy
  path in `src/ascend/silu_and_mul/kernel.h`.
…aryPosEmb base ops

Merge the two rope base headers into one vLLM-compatible op matching
`RotaryEmbedding.forward(positions, query, key=None) -> (query, key|None)`.
`key` becomes `std::optional<Tensor>` (MLA), `query_out` / `key_out` remain
optional for the vLLM-native inplace path, and a new `bool pre_gathered`
constructor flag folds the old `ApplyRotaryPosEmb` fast path into the
unified op.

Kernel updates across all three Ascend impls:
- impl 0 (`aclnnApplyRotaryPosEmbV2`) and impl 1 (ATB `RopeParam`) accept
  the optional `key` / out tensors and honor `pre_gathered` (skipping
  internal `aclnnIndexSelect` when the caller has pre-gathered).
- impl 0 and impl 1 re-upload the expanded cos/sin tables on cache-pointer
  change (reviewer-flagged stale-pointer bug).
- impl 2 (`aclnnRopeWithSinCosCache`) destroys its per-call
  `aclOpExecutor` instead of leaking it (reviewer-flagged leak).
- Uppercase locals (`D`, `T`, `Nq`, `Nkv`, `half_D`, `hiddenQ`,
  `hiddenK`) renamed to snake_case, and `uploadCosSinCache` renamed to
  `UploadCosSinCache` per Google C++ style.
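The stale-pointer guard from the second bullet can be sketched as follows (illustrative class and method names; the real launchers cache device-side tables in C++):

```python
class CosSinUploader:
    """Re-upload the expanded cos/sin tables only when the caller's
    cache pointer changes between calls (hypothetical sketch)."""

    def __init__(self):
        self._last_ptr = None

    def maybe_upload(self, cache_ptr, upload):
        # A changed pointer means the cached device copy is stale.
        if cache_ptr != self._last_ptr:
            upload(cache_ptr)
            self._last_ptr = cache_ptr

calls = []
u = CosSinUploader()
u.maybe_upload(0x1000, calls.append)
u.maybe_upload(0x1000, calls.append)  # same pointer: no re-upload
u.maybe_upload(0x2000, calls.append)  # pointer changed: re-upload
assert calls == [0x1000, 0x2000]
```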
After the `ApplyRotaryPosEmb` base class was folded into the unified
`RotaryEmbedding` op, vllm-infini still calls
`infini.ops.apply_rotary_pos_emb(...)` — preserve that symbol as a
pybind11 Python-level shim bound alongside the generated
`rotary_embedding` binding.

The shim un-expands the caller's neox-duplicated `[T, head_size]` cos /
sin halves, concats into a `[T, head_size*2]` pre-gathered cache,
synthesizes `positions = arange(T)`, and forwards to the unified op
with `pre_gathered=True`.  No vllm-infini changes are needed.
…3D/partial

Consolidate `test_apply_rotary_pos_emb.py` (deleted separately) into
`test_rotary_embedding.py`:

- `test_apply_rotary_pos_emb`      — pre-gathered fast path through the
  new Python shim; asserts bit-exact parity against
  `infini.ops.rotary_embedding` on the same data.
- `test_apply_rotary_pos_emb_3d`   — 3D `[T, Nq, D]` / `[T, Nkv, D]`
  layout through the shim (reviewer gap).
- `test_rotary_embedding_partial`  — extend to cover
  `is_neox_style=False` on impl 2 (`aclnnRopeWithSinCosCache`),
  matching the reviewer's partial-rotary gap on the non-neox path.
- `_ref_rotary_embedding` now tolerates `key=None` (MLA).
…nature

Without this, the unified `RotaryEmbedding`'s new `bool pre_gathered`
parameter became a required positional kwarg on the Python side, breaking
every existing `infini.ops.rotary_embedding(...)` caller that did not
pass it.  Regex-scan the base header for `<scalar_type> name = <literal>`
patterns and emit `py::arg(name) = <literal>` in `_generate_py_args`.

Also restore the default on the virtual `operator()` override in
`src/base/rotary_embedding.h` so the regex picks it up.
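The default-scan idea can be sketched like this (a simplified regex and a hypothetical helper name, not the actual `_generate_py_args` implementation):

```python
import re

# Match `<scalar_type> name = <literal>` parameter declarations so
# the default can be carried into the emitted `py::arg`.
_DEFAULT_RE = re.compile(
    r"\b(?:bool|int|int64_t|float|double|size_t)\s+(\w+)\s*=\s*([\w.+-]+)"
)

def py_args_with_defaults(signature: str) -> list[str]:
    return [
        f'py::arg("{name}") = {literal}'
        for name, literal in _DEFAULT_RE.findall(signature)
    ]

sig = "void operator()(Tensor query, bool pre_gathered = false)"
assert py_args_with_defaults(sig) == ['py::arg("pre_gathered") = false']
```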
…ncos executor destroy

Two in-flight regressions from the previous commit:

1. The `pre_gathered=true` path in kernel.h / kernel_atb.h assumed the
   caller's `cos_sin_cache` is `[T, head_size*2]` (dim-1 concat), but
   that layout can't be split with a flat byte offset because row-major
   contiguous layout interleaves cos and sin per row.  Change the wire
   format to `[2T, head_size]` (dim-0 concat) so the first
   `T * head_size * elem_sz` bytes are contiguous cos and the next
   are contiguous sin; update both kernels and the `apply_rotary_pos_emb`
   Python shim to match.

   Also set the initial `sin_v2_cache_` base pointer to the sin offset
   so the V2 executor captures distinct cos/sin addresses on first call.

2. `kernel_sincos_cache.h` (impl 2) SIGABRTs when the per-call
   `aclOpExecutor*` is destroyed right after `aclnnRopeWithSinCosCache`
   — the kernel is async on the stream and the executor backs the
   enqueued launch.  Revert the `aclDestroyAclOpExecutor` call (still
   leaks, but matches the prior behavior that passed all partial-rotary
   tests) and leave a TODO for proper Repeatable-executor caching once
   the input-address index layout for this kernel is confirmed.
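A small NumPy sketch of why only the dim-0 concat is splittable with a flat byte offset (shapes taken from the commit text, values illustrative):

```python
import numpy as np

T, H = 4, 8
cos = np.arange(T * H, dtype=np.float32).reshape(T, H)
sin = -cos

# dim-1 concat ([T, 2H]): row-major layout interleaves a cos row and a
# sin row per token, so the first T*H flat elements mix both tables.
dim1 = np.concatenate([cos, sin], axis=1)
assert not np.array_equal(dim1.ravel()[: T * H].reshape(T, H), cos)

# dim-0 concat ([2T, H]): the first T*H flat elements are exactly cos
# and the next T*H are exactly sin, so one byte offset splits them.
dim0 = np.concatenate([cos, sin], axis=0)
assert np.array_equal(dim0.ravel()[: T * H].reshape(T, H), cos)
assert np.array_equal(dim0.ravel()[T * H :].reshape(T, H), sin)
```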
The GPT-J-style branch in `_ref_rotary_embedding` indexed `x[t, :, 0::2]`
and `x[t, :, 1::2]` across the full `head_size` — correct only when
`rotary_dim == head_size`.  For partial rotary, only the first
`rotary_dim` features rotate; restrict slices to `0:R:2` and `1:R:2`.
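A minimal NumPy sketch of the restricted slicing (hypothetical helper name; broadcasting is simplified relative to the real `_ref_rotary_embedding`):

```python
import numpy as np

def rotate_gptj(x, cos, sin, rotary_dim):
    """Rotate only the first `rotary_dim` features in GPT-J
    (interleaved) style; the tail passes through untouched."""
    out = x.copy()
    R = rotary_dim
    x1 = x[..., 0:R:2]
    x2 = x[..., 1:R:2]
    out[..., 0:R:2] = x1 * cos - x2 * sin
    out[..., 1:R:2] = x2 * cos + x1 * sin
    return out

x = np.random.rand(2, 3, 8).astype(np.float32)  # [T, heads, head_size]
R = 4
cos = np.random.rand(R // 2).astype(np.float32)
sin = np.random.rand(R // 2).astype(np.float32)
y = rotate_gptj(x, cos, sin, R)
# Features beyond rotary_dim are unchanged under partial rotary.
assert np.array_equal(y[..., R:], x[..., R:])
```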
…ixes

Post-merge /simplify review findings applied:

- **`AddRmsNorm` param rename** (`src/base/add_rms_norm.h` + 3 Ascend kernels + test):
  `rstd_out` → `residual_out`.  The slot actually holds `xOut` (the
  `input + other` residual sum) per `aclnnAddRmsNorm`'s API — the internal
  `rstd_tensor_` reciprocal-std buffer is private.  Prior name was
  misleading.
- **Generator shim for `apply_rotary_pos_emb`** (`scripts/generate_wrappers.py`):
  rename the `head_size`-as-`rotary_dim` positional forward to a named local
  `rotary_dim_shim` + comment noting the legacy shim assumes full rotary
  (`rotary_dim == head_size`).
- **`kernel_sincos_cache.h` leak comment**: TODO → FIXME with persistent-worker
  impact call-out.  Actual fix still blocked on undocumented input-address
  index layout for `aclnnRopeWithSinCosCache`.

Skipped findings: reviewer false positives on `src/base/rotary_embedding.h`
members (all consumed by kernels) and `max_seq_len_` (used in constructor
body).  Larger refactors (UploadCosSinCache + IndexSelect helpers, ~100
lines copy-paste) deferred to a follow-up PR.
… dep tracking

In-tree `ascendc_library()` trips a `CANN` `extract_host_stub.py` path
bug (`KeyError` on `/./workspace/...` paths in `$<TARGET_OBJECTS>`)
whenever it runs under `scikit-build-core`'s temp-dir builds.  Standalone
`src/ascend/custom/build.sh` avoids the bug by invoking a separate
`cmake` with `src/ascend/custom/` as its `SOURCE_DIR`.  This commit
drives `build.sh` from the main build so devs / CI get a working install
from a single `pip install` call.

- `option(BUILD_ASCEND_CUSTOM …)` replaces the old `BUILD_CUSTOM_KERNEL`
  (name is Ascend-specific now that the driver is CMake-native) and
  **defaults to ON**.  Non-Ascend builds ignore it (gated by
  `WITH_ASCEND` in `src/CMakeLists.txt`); users who don't want the
  `ccec` build on Ascend pass `-DBUILD_ASCEND_CUSTOM=OFF`.

- `src/CMakeLists.txt` registers `build.sh` as a build-phase
  `add_custom_command(OUTPUT …/libno_workspace_kernel.a)` with explicit
  dependencies on every `src/ascend/custom/**/*.{cpp,h}` file (via
  `file(GLOB_RECURSE … CONFIGURE_DEPENDS)`) — edits to any `op_host/` or
  `op_kernel/` source now re-trigger the build, instead of silently
  reusing a stale `.a`.  The outer `scikit-build-core` env (`CMAKE_GENERATOR`,
  `CMAKE_EXPORT_COMPILE_COMMANDS`, …) is scrubbed via `cmake -E env
  --unset=…` before invoking `build.sh` — leaving them set makes the
  nested `cmake`'s `ninja` generator emit the bug-triggering
  `/./workspace/...` paths even though the outer configure dir is clean.

- `src/ascend/custom/cmake/detect_soc.cmake` holds
  `infiniops_detect_soc(<out>)`, which parses `npu-smi info` for the
  first `910*` / `310*` entry and falls back to `Ascend910B4`.  Both
  `src/CMakeLists.txt` (outer build) and
  `src/ascend/custom/cmake/config_ascend.cmake` (sub-build driven by
  `build.sh`) `include()` this file — SOC detection lives in one place.

- `src/ascend/custom/CMakeLists.txt` pushes the main `src/` dir onto
  the interface target's `INCLUDES` property so the kernel TU can
  `#include "data_type.h"`.

- `src/ascend/custom/add_rms_norm/op_kernel/.clang-tidy`: disables all
  `clang-tidy` checks on `ccec`-compiled device code (absent from
  `compile_commands.json`, `__aicore__` macro parses incorrectly
  without `kernel_operator.h`).

Dev workflow: `pip install -e .[dev]` gives a fully working install on
Ascend; editing any custom-kernel source and re-running `pip install`
re-triggers the `ccec` build automatically.
The `AscendC` custom kernels forward `static_cast<int64_t>(input.dtype())`
to their `aclrtlaunch_*` entry points and dispatch on the same enum —
making `DataType`'s integer values part of a host↔device ABI.

Assign explicit values (`kInt8 = 0, …, kFloat64 = 11`) to pin that ABI:
reordering or inserting entries above existing ones would silently
change the integers seen by device code.  No behaviour change at call
sites (the enum is still accessed by symbolic name everywhere except
the `int64_t` cast).
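The hazard is easy to demonstrate with an illustrative enum (not the real `DataType`):

```python
from enum import IntEnum

# Without pinned values, inserting an entry shifts every value after
# it, silently changing the integer the device code dispatches on.
class Before(IntEnum):
    A = 0
    B = 1

class After(IntEnum):  # a new entry inserted above `B`
    A = 0
    NEW = 1
    B = 2

assert Before.B != After.B  # the wire value of `B` changed
```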
bf16 was silently producing garbage / NaN on impl 1 (`rms_norm`) and
impl 2 (`add_rms_norm`): the kernels only instantiated `<half>` and
`<float>`, and the launchers mapped bf16 to the fp32 byte-size path,
so bf16 weight was read as if it were fp32 and the fp16 output cast
used `CAST_ROUND` (fp16-only alias).

Kernel dispatch:

- `op_kernel/rms_norm.cpp` / `op_kernel/add_rms_norm.cpp`: add a
  `KernelXxx<bfloat16_t>` instantiation; dispatch in the `extern "C"`
  entry is now `switch (static_cast<infini::ops::DataType>(dtypeCode))`
  (shared enum forwarded from the launcher via `int64_t`).  The
  fp16/bf16 branch uses `CAST_RINT` for the fp32 → T writeback —
  defined for both `half` and `bfloat16_t` destinations, whereas
  `CAST_ROUND` is a `half`-specific alias.

Launchers (`kernel_custom.h`):

- Store `DataType dtype_` (replaces the old `int64_t dtype_size_` which
  collapsed fp16 and bf16 onto the same code).
- Use `ascend::ToAclDtype(dtype_)` and `kDataTypeToSize.at(dtype_)`
  instead of hand-rolled ternaries (consistent with the rest of the
  Ascend backend).
- Forward `static_cast<int64_t>(dtype_)` as the kernel's `dtypeCode`.
- `extern "C" aclrtlaunch_*` forward-decl parameters renamed to
  `snake_case`; the function name itself is generated by
  `ascendc_add_operator(OP_NAME …)` and carries
  `// NOLINTNEXTLINE(readability-identifier-naming)` so `clang-tidy`
  accepts it.

Identifier naming (Google C++ Style):

- `op_kernel/*.cpp` members `snake_case_`, params / locals `snake_case`,
  constants `kPascalCase` (was `BUFFER_NUM` / `dimLength` / `inQueueX1`
  / `blockRows`, etc. — inherited from the `vllm-ascend` sample style).

Verified: `pytest tests/test_rms_norm.py tests/test_add_rms_norm.py
--devices ascend` → 144 passed / 0 failed (fp32 / fp16 / bf16 × both
ops × full shape + stride matrix).
…th vLLM

Bring `src/base/*.h` interfaces and tensor conventions into strict alignment
with vLLM's public kernel contracts.  Derived Ascend kernels and tests
follow.  `generated/bindings/` will regenerate on next build.

- **`SiluAndMul`**: rename `x` → `input` (matches `F.glu(input, dim)`); add
  `(input, out)` overload with `dim = -1` default to match vLLM's hardcoded
  last-dim behavior.
- **`Linear`**: add vLLM-aligned `(input, weight, bias?, out)` overload with
  weight stored as `[out_features, in_features]` (identical to
  `F.linear(input, weight, bias)`).  Deprecated 6-arg
  `(a, b, bias, trans_a, trans_b, out)` form retained.  CPU and Ascend
  subclasses gain matching 4-arg ctors that delegate to the 6-arg form with
  `trans_a = false, trans_b = true`.
- **`AddRmsNorm`**: rename `other` → `residual` (matches vLLM's
  `fused_add_rms_norm(input, residual, weight, eps)` schema); add inplace
  `(input, residual, weight, eps)` overload that forwards to the
  out-of-place primary form with aliased buffers.
- **`RotaryEmbedding`**: reorder first six parameters to match vLLM's
  `rotary_embedding(positions, query, key?, head_size, cos_sin_cache,
  is_neox)` schema verbatim; `rotary_dim` / `query_out?` / `key_out?` /
  `pre_gathered` remain as InfiniOps extensions at the tail.  Added
  `positions.dtype() == int64` assert per vLLM convention.

Verified on NPU: `pytest tests/test_{silu_and_mul,add_rms_norm,rotary_embedding,linear}.py --devices ascend` → 295 passed, 4 skipped, 0 failed.
Follow-up to `c23901a`.  Per CLAUDE.md "default to writing no comments",
strip doc-comments that narrate the change or restate well-named
identifiers from the four refactored base headers.  Keep only the one
WHY comment in `rotary_embedding.h` explaining `pre_gathered`'s
index_select+neox precondition (the name alone doesn't carry it).

Also replace the two delegating ctors in `src/cpu/linear/linear.h` with
`using Linear::Linear;` — matches the pattern already used in
`src/cpu/{rms_norm,swiglu}/*.h`, `src/cuda/{rms_norm,causal_softmax}/*.h`.

Verified: `pytest tests/test_{silu_and_mul,add_rms_norm,rotary_embedding,linear}.py --devices ascend` → 295 passed, 4 skipped.
- `tests/test_add_rms_norm.py`: extend `implementation_index` parametrize
  to `(0, 1, 2)`; add `_clear_add_rms_norm_cache` autouse fixture to
  avoid cross-test state pollution in the custom AscendC kernel (impl 2)
  whose cached fp32 weight buffer collides across tests with matching
  shape/dtype keys.  Coverage: +54 test cases (108 total, all green).

- `src/base/rotary_embedding.h`: assert `key.has_value()` with a TODO
  noting MLA is not yet implemented on any Ascend backend.  All three
  impls already assert `has_key_` individually; hoisting the check to
  base turns a silent crash (if a caller passes `key=None`) into a clean
  assert.  Keeps `std::optional<Tensor> key` in the signature for future
  MLA support without breaking vLLM API compatibility.

- `src/ascend/causal_softmax/kernel.h`: add justification for the
  3-primitive decomposition (no single CANN 8.5 API covers causal-mask
  + softmax; `aclnnSoftmaxV2` lacks the mask argument, and
  `aclnnScaledMaskedSoftmax` requires a pre-scaled attention score), per
  CLAUDE.md Ascend rule "never decompose when a fused API exists".

Verified: `pytest tests/test_{silu_and_mul,add_rms_norm,rotary_embedding,linear,causal_softmax}.py --devices ascend` → 349 passed, 4 skipped.
The legacy `apply_rotary_pos_emb` shim existed only as a vllm-infini
compat alias after the `ApplyRotaryPosEmb` base op was folded into the
unified `RotaryEmbedding`.  vllm-infini is out of scope for this PR, so
drop the shim entirely:

- `scripts/generate_wrappers.py`: remove `_generate_apply_rotary_pos_emb_shim`
  and the `extra_shim` emission hook — the Python-level wrapper was
  ~60 lines of pybind C++ that concatenated cos/sin, synthesized
  `positions = arange(T)`, and forwarded to `rotary_embedding` with
  `pre_gathered=True`.  Callers that need the pre-gather fast path can
  invoke `infini.ops.rotary_embedding(..., pre_gathered=True)` directly.
- `tests/test_rotary_embedding.py`: remove `test_apply_rotary_pos_emb` /
  `test_apply_rotary_pos_emb_3d` and the `_expand_cos_sin` helper that
  only those tests used.  `pre_gathered=True` remains exercised
  indirectly via `test_rotary_embedding_full` when impl 2 requires the
  caller to pre-gather (handled internally by the kernel).
- Touch up two stale `apply_rotary_pos_emb shim` comments in
  `kernel{,_atb}.h` that no longer point anywhere.

Verified: `pytest tests/ --devices ascend` → 2278 passed, 1612 skipped
(was 2306 / 1612 — delta is the 28 removed `apply_rotary_pos_emb` cases).
Fold the deleted `test_apply_rotary_pos_emb` / `_3d` cases into a single
`test_rotary_embedding_pre_gathered` that exercises the `pre_gathered`
fast path directly on the `rotary_embedding` overload (no shim).
Parametrize over 2D / 3D query-key layouts, impls 0 and 1 (impl 2 asserts
`!pre_gathered_`), neox / GPT-J styles, fp16 / bf16.  The new
`_build_pre_gathered_cache` helper constructs the `[2*T, head_size]`
wire format that `src/ascend/rotary_embedding/kernel.h` expects —
cos rows 0..T-1, sin rows T..2T-1, both neox-expanded per token.

Coverage: 12 new cases pass (4 skip for `impl=0 + not-neox`, same as the
`test_rotary_embedding_full` skip — V2 only supports `rotaryMode="half"`).

Full rotary suite: 88 passed, 8 skipped (was 80 passed, 4 skipped before
this test was added).
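The wire format the helper builds can be sketched as follows (helper name and exact neox expansion assumed from the commit text):

```python
import numpy as np

def build_pre_gathered_cache(cos_half, sin_half):
    """Sketch of the `[2*T, head_size]` wire format: cos occupies rows
    0..T-1 and sin rows T..2T-1, each row the per-token half table
    duplicated neox-style to the full head_size."""
    cos = np.concatenate([cos_half, cos_half], axis=-1)  # [T, head_size]
    sin = np.concatenate([sin_half, sin_half], axis=-1)
    return np.concatenate([cos, sin], axis=0)            # [2*T, head_size]

T, half = 3, 4
cos_half = np.random.rand(T, half).astype(np.float32)
sin_half = np.random.rand(T, half).astype(np.float32)
cache = build_pre_gathered_cache(cos_half, sin_half)
assert cache.shape == (2 * T, 2 * half)
assert np.array_equal(cache[:T, :half], cos_half)   # cos rows first
assert np.array_equal(cache[T:, half:], sin_half)   # then sin rows
```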
- `src/base/add_rms_norm.h`: `#include <cstddef>` — no `size_t` usage.
- `src/base/rotary_embedding.h`: same.
- `src/ascend/add_rms_norm/kernel_custom.h`: `#include <vector>` — no
  `std::vector` / `std::array` usage.

Build + 355 passed / 8 skipped on Ascend unchanged.
Addresses inline review comments on #66 (reviewer: Ziminli) across all
PR-touched files:

- C4: strip trailing periods from assert messages; lowercase the
  sentence-starting word when it is bare English (e.g. "Ascend ..." →
  "ascend ..."), leave backticked identifiers untouched.
- G4: backtick `RmsNorm` in kernel_custom.h header comment; backtick
  `aclnn` / `cos_sin_cache` / `infini.ops.add_rms_norm(...)` in kernel
  comments that were still running raw text.
- C2: rename `aclrtlaunch_add_rms_norm` / `aclrtlaunch_rms_norm`
  forward-decl parameter names from AscendC internals (`x1, x2, y,
  x_out`) to the base-header semantic names (`input, residual, weight,
  out, residual_out`).  The extern "C" symbol is name-blind so the
  AscendC kernel .cpp can keep its local names — the wrapper .h just
  presents the public contract.
- Pre-gathered rotary test: drop the hardcoded
  `implementation_index=(0, 1)` parametrize, let conftest auto-inject
  and skip impl 2 inline (the impl 2 kernel asserts
  `!pre_gathered_`).

Verified locally (`--gpu-id 3/4/5 --local`):
  test_add_rms_norm.py:      108 passed
  test_rms_norm.py:            72 passed
  test_rotary_embedding.py:    88 passed, 16 skipped (expected:
                                          impl 2 + pre_gathered,
                                          impl 0 + non-neox)
…m order

Addresses Ziminli's comment on `aclrtlaunch_add_rms_norm` forward-decl
(#66 discussion 3115868675 / 3129096521):

- **Function-name format:** the AscendC kernel entry-points `add_rms_norm` /
  `rms_norm` are renamed to `AddRmsNorm` / `RmsNorm`.  The AscendC
  toolchain prepends `aclrtlaunch_` on the symbol regardless of case,
  so the exported names become `aclrtlaunch_AddRmsNorm` /
  `aclrtlaunch_RmsNorm` — matching the base-class names and
  `readability-identifier-naming.FunctionCase = CamelCase`.
  The `NOLINTNEXTLINE(readability-identifier-naming)` shim and the
  "PascalCase rule does not apply" apology comments go away.

- **Parameter-list order (C2):** reorder parameters to `inputs, attributes,
  outputs`.  Both `.cpp` kernel entry, `KernelAddRmsNorm::Init` /
  `KernelRmsNorm::Init`, and the `extern "C"` forward-decl in
  `kernel_custom.h` are updated together, along with the call sites
  in `operator()`.

- **Variable naming (`.cpp` internals):** `x1/x2/y/x_out` →
  `input/residual/out/residual_out`; `x/y` → `input/out`.  Cascaded
  through member names (`*_gm_`, `*_queue_*`, `*_local`) for
  consistency — internal to each kernel class, no ABI impact.

- **`op_host/*.cpp`:** updated to include the PascalCase generated
  header `aclrtlaunch_AddRmsNorm.h` / `aclrtlaunch_RmsNorm.h` and to
  match the reordered `EXEC_KERNEL_CMD` argument list.

Verified locally with `.ci/run.py --local`:
  test_add_rms_norm.py:      108 passed
  test_rms_norm.py:            72 passed

The AscendC toolchain successfully compiles the PascalCase kernel
entries and generates matching launch headers — the
`aclrtlaunch_<ENTRY>` macro concatenates regardless of case.
/simplify found 4 comment blocks that narrate the rename rationale
rather than encode load-bearing contracts:

- `kernel_custom.h` forward-decl — compress build-system detail
  (`no_workspace_kernel`, `ascendc_library()`) to one line, keep only
  the ABI contract (`aclrtlaunch_<Entry>` is generated by AscendC from
  `op_kernel/`).
- `op_host/<op>.cpp` `EXEC_KERNEL_CMD` — drop "Parameter order follows
  the base class: inputs, attributes, outputs."; the signature itself
  is self-evident.
- `op_kernel/<op>.cpp` kernel entry — drop "Parameters follow the C2
  convention ..." and "`aclrtlaunch_AddRmsNorm` matches the base
  `AddRmsNorm` class name"; these are commit-message material, not
  comments.
`src/CMakeLists.txt:442` referenced `no_workspace_kernel` /
`no_workspace_kernel_build` from inside `if(GENERATE_PYTHON_BINDINGS)`
without checking `WITH_ASCEND`.  Those targets are only created in the
`WITH_ASCEND` block above (244-309), so non-Ascend
`pip install -e .[dev]` failed at CMake configure with `No target
"no_workspace_kernel"` and `dependency target "no_workspace_kernel_build"
does not exist`.

Mirror the gate: `if(WITH_ASCEND AND BUILD_ASCEND_CUSTOM)`.
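As a sketch, the mirrored gate looks like the fragment below. The target names come from the report above; the binding-target name `python_bindings` and the exact link/dependency calls are placeholders for whatever the real `GENERATE_PYTHON_BINDINGS` block does:

```cmake
if(GENERATE_PYTHON_BINDINGS)
  # Only reference the AscendC custom-kernel targets when the
  # WITH_ASCEND block above has actually created them.
  if(WITH_ASCEND AND BUILD_ASCEND_CUSTOM)
    target_link_libraries(python_bindings PRIVATE no_workspace_kernel)
    add_dependencies(python_bindings no_workspace_kernel_build)
  endif()
endif()
```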

Verified non-Ascend (`-DWITH_ASCEND=OFF -DWITH_CPU=ON`) and Ascend
(auto-detect) configure both pass.
@zhangyue207 zhangyue207 requested a review from a team April 25, 2026 19:29
@zhangyue207 zhangyue207 marked this pull request as draft April 25, 2026 20:36
zhangyue and others added 2 commits April 27, 2026 15:41
The `--unset=PYTHONPATH` in `add_custom_command` and the `CMAKE_EXE`
forwarding in `build.sh` existed solely so `pip install -e .[dev]`
(default build-isolation) could drive `build.sh` on Ascend.  Drop both
and require `--no-build-isolation` for Ascend installs — the official
workflow.  The `add_custom_command` driver, `BUILD_ASCEND_CUSTOM`
gate, and `detect_soc.cmake` helper are unrelated to build isolation
and stay.
…ant platform iteration

`skip_op_without_platform_impl` iterated `_TORCH_DEVICE_TO_PLATFORMS["cuda"]
= ("nvidia", "metax", "iluvatar")` and called
`active_implementation_indices(p)` for each, regardless of whether `p` was
in this build's `ActiveDevices`.  On an nvidia-only build (the typical
local dev env), the very first call with `"metax"` hit
`DispatchFunc`'s `std::abort()` path — not a Python exception — which
SIGABRT'd the whole pytest worker.  Same crash for hardcoded
`device="npu"` parametrize (`test_rotary_embedding_atb[npu-...]`) on a
non-ascend build.

The platform iteration was redundant: `DeviceTypeFromString` in
`pybind11_utils.h` builds its `TorchNameMap` from `ActiveDevices<>`, so
passing the torch device type directly (`"cuda"`) already resolves to
whichever cuda-family platform is in this configuration.  Replaced the
loop with a single `active_implementation_indices(params["device"])`
call and pre-filtered against `get_available_devices()` so hardcoded
out-of-build parametrize skips before the C++ call.

Drops `_TORCH_DEVICE_TO_PLATFORMS` (only used here).
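The replaced skip logic can be sketched as follows; `get_available_devices` and `active_implementation_indices` stand in for the real bindings, and the helper name and return convention are illustrative:

```python
def skip_op_without_platform_impl(params, get_available_devices,
                                  active_implementation_indices):
    """Sketch of the pre-filtered skip check described above."""
    device = params["device"]
    # Pre-filter: a hardcoded out-of-build device (e.g. "npu" on a
    # non-ascend build) skips before any C++ call, avoiding the
    # std::abort() path in DispatchFunc.
    if device not in get_available_devices():
        return "skip: device not in this build"
    # Single call with the torch device type; the binding resolves it
    # to whichever platform of that family is in this configuration.
    if params["impl"] not in active_implementation_indices(device):
        return "skip: impl not available"
    return None  # run the test
```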
`pytest.skip(...)` call collapsed onto a single line — line is short
enough after dropping the multi-platform message in the prior commit.
`ruff format --check` now passes.
@zhangyue207
Collaborator Author

nv

tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-0.5-a_shape3-b_shape3-c_shape3-a_strides3-b_strides3-c_strides3]
[gw6] [ 99%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-0.5-a_shape3-b_shape3-c_shape3-a_strides3-b_strides3-c_strides3]
tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-0.5-a_shape4-b_shape4-c_shape4-None-None-None]
[gw6] [ 99%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-0.5-a_shape4-b_shape4-c_shape4-None-None-None]
tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-1-a_shape0-b_shape0-c_shape0-None-None-None]
[gw6] [ 99%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-1-a_shape0-b_shape0-c_shape0-None-None-None]
tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-1-a_shape1-b_shape1-c_shape1-None-None-None]
[gw6] [ 99%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-1-a_shape1-b_shape1-c_shape1-None-None-None]
tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-1-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
[gw6] [ 99%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-1-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-1-a_shape3-b_shape3-c_shape3-a_strides3-b_strides3-c_strides3]
[gw6] [100%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype1-0.01-0.01-True-False--1-1-a_shape3-b_shape3-c_shape3-a_strides3-b_strides3-c_strides3]

----------- generated xml file: /workspace/results/test-results.xml ------------
================ 6295 passed, 2548 skipped in 105.57s (0:01:45) ================
========== Summary ==========

moore

FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-0-0.5-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-0-1-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-0.5--1-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-0.5--0.5-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-0.5-0-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-0.5-0.5-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-0.5-1-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-1--1-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-1--0.5-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-1-0-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-1-0.5-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
FAILED tests/test_gemm.py::test_gemm[musa-2-dtype2-0.01-0.01-True-True-1-1-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
========== 300 failed, 5072 passed, 1108 skipped in 152.21s (0:02:32) ==========
Stage 'test' failed with exit code 1
========== Summary ==========
job moore_gpu failed (exit code 1)

cambricon

[gw1] [ 99%] SKIPPED tests/test_gemm.py::test_gemm[cpu-2-dtype1-0.01-0.01-False-False--0.5-0.5-a_shape3-b_shape3-c_shape3-a_strides3-b_strides3-c_strides3]
tests/test_gemm.py::test_gemm[cpu-2-dtype1-0.01-0.01-False-False--0.5-0.5-a_shape4-b_shape4-c_shape4-None-None-None]
[gw1] [ 99%] SKIPPED tests/test_gemm.py::test_gemm[cpu-2-dtype1-0.01-0.01-False-False--0.5-0.5-a_shape4-b_shape4-c_shape4-None-None-None]
tests/test_gemm.py::test_gemm[cpu-2-dtype1-0.01-0.01-False-False--0.5-1-a_shape0-b_shape0-c_shape0-None-None-None]
[gw1] [ 99%] SKIPPED tests/test_gemm.py::test_gemm[cpu-2-dtype1-0.01-0.01-False-False--0.5-1-a_shape0-b_shape0-c_shape0-None-None-None]
tests/test_gemm.py::test_gemm[cpu-2-dtype1-0.01-0.01-False-False--0.5-1-a_shape1-b_shape1-c_shape1-None-None-None]
[gw1] [ 99%] SKIPPED tests/test_gemm.py::test_gemm[cpu-2-dtype1-0.01-0.01-False-False--0.5-1-a_shape1-b_shape1-c_shape1-None-None-None]
tests/test_gemm.py::test_gemm[cpu-2-dtype1-0.01-0.01-False-False--0.5-1-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
[gw1] [ 99%] SKIPPED tests/test_gemm.py::test_gemm[cpu-2-dtype1-0.01-0.01-False-False--0.5-1-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
tests/test_gemm.py::test_gemm[cpu-2-dtype1-0.01-0.01-False-False--0.5-1-a_shape3-b_shape3-c_shape3-a_strides3-b_strides3-c_strides3]
[gw1] [100%] SKIPPED tests/test_gemm.py::test_gemm[cpu-2-dtype1-0.01-0.01-False-False--0.5-1-a_shape3-b_shape3-c_shape3-a_strides3-b_strides3-c_strides3]

----------- generated xml file: /workspace/results/test-results.xml ------------
================ 2500 passed, 3500 skipped in 289.41s (0:04:49) ================
========== Summary ==========

metax

tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-0.5-a_shape4-b_shape4-c_shape4-None-None-None]
[gw1] [ 99%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-0.5-a_shape4-b_shape4-c_shape4-None-None-None]
tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-1-a_shape0-b_shape0-c_shape0-None-None-None]
[gw1] [ 99%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-1-a_shape0-b_shape0-c_shape0-None-None-None]
tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-1-a_shape1-b_shape1-c_shape1-None-None-None]
[gw1] [ 99%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-1-a_shape1-b_shape1-c_shape1-None-None-None]
tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-1-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
[gw1] [ 99%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-1-a_shape2-b_shape2-c_shape2-a_strides2-b_strides2-c_strides2]
tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-1-a_shape3-b_shape3-c_shape3-a_strides3-b_strides3-c_strides3]
[gw1] [ 99%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-1-a_shape3-b_shape3-c_shape3-a_strides3-b_strides3-c_strides3]
tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-1-a_shape4-b_shape4-c_shape4-None-None-None]
[gw1] [ 99%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-0.5-1-a_shape4-b_shape4-c_shape4-None-None-None]
tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-1--1-a_shape0-b_shape0-c_shape0-None-None-None]
[gw1] [100%] PASSED tests/test_gemm.py::test_gemm[cpu-0-dtype2-0.01-0.01-False-True-1--1-a_shape0-b_shape0-c_shape0-None-None-None]

----------- generated xml file: /workspace/results/test-results.xml ------------
================ 5795 passed, 1548 skipped in 176.37s (0:02:56) ================
========== Summary ==========

ascend

tests/test_swiglu.py::test_swiglu[npu-1-dtype1-0.001-0.001-shape4-None-None-None] PASSED [ 99%]
tests/test_swiglu.py::test_swiglu[npu-1-dtype1-0.001-0.001-shape5-input_strides5-gate_strides5-out_strides5] PASSED [ 99%]
tests/test_swiglu.py::test_swiglu[npu-1-dtype1-0.001-0.001-shape6-None-None-None] PASSED [ 99%]
tests/test_swiglu.py::test_swiglu[npu-1-dtype1-0.001-0.001-shape7-input_strides7-gate_strides7-out_strides7] PASSED [ 99%]
tests/test_swiglu.py::test_swiglu[npu-1-dtype2-0.01-0.005-shape0-None-None-None] PASSED [ 99%]
tests/test_swiglu.py::test_swiglu[npu-1-dtype2-0.01-0.005-shape1-input_strides1-gate_strides1-out_strides1] PASSED [ 99%]
tests/test_swiglu.py::test_swiglu[npu-1-dtype2-0.01-0.005-shape2-None-None-None] PASSED [ 99%]
tests/test_swiglu.py::test_swiglu[npu-1-dtype2-0.01-0.005-shape3-input_strides3-gate_strides3-out_strides3] PASSED [ 99%]
tests/test_swiglu.py::test_swiglu[npu-1-dtype2-0.01-0.005-shape4-None-None-None] PASSED [ 99%]
tests/test_swiglu.py::test_swiglu[npu-1-dtype2-0.01-0.005-shape5-input_strides5-gate_strides5-out_strides5] PASSED [ 99%]
tests/test_swiglu.py::test_swiglu[npu-1-dtype2-0.01-0.005-shape6-None-None-None] PASSED [ 99%]
tests/test_swiglu.py::test_swiglu[npu-1-dtype2-0.01-0.005-shape7-input_strides7-gate_strides7-out_strides7] PASSED [100%]

===================== 2290 passed, 1624 skipped in 20.52s ======================

@zhangyue207 zhangyue207 marked this pull request as ready for review April 28, 2026 04:44
@zhangyue207 zhangyue207 changed the title feat(ascend): op-norm-rope group (re-PR after #72 revert) — Swiglu, SiluAndMul, CausalSoftmax, RmsNorm, AddRmsNorm, ApplyRotaryPosEmb, RotaryEmbedding feat(ascend): op-norm-rope group (re-PR after #72 revert) — Swiglu, SiluAndMul, CausalSoftmax, RmsNorm, AddRmsNorm, RotaryEmbedding Apr 28, 2026